package org.hong.monkey.serializer

import java.io.{ByteArrayInputStream, ByteArrayOutputStream}
import java.nio.ByteBuffer

import com.esotericsoftware.kryo.{Kryo, Serializer => KSerializer}
import com.esotericsoftware.kryo.io.{Input => KryoInput, Output => KryoOutput}
import org.apache.avro.{Schema, SchemaNormalization}
import org.apache.avro.Schema.Parser
import org.apache.avro.generic.{GenericData, GenericRecord}
import org.apache.avro.io.{DatumReader, DatumWriter, DecoderFactory, EncoderFactory}
import org.apache.commons.io.IOUtils
import org.hong.monkey.{MonkeyEnv, MonkeyException}
import org.hong.monkey.io.CompressionCodec

import scala.collection.mutable

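/**
 * Custom Kryo serializer for generic Avro records.
 *
 * If a schema has been registered ahead of time (via the `schemas` map from fingerprint
 * to schema JSON), only its 64-bit fingerprint is written with each record; otherwise the
 * full schema is compressed and written inline. Parsing, compressing and fingerprinting
 * schemas are expensive, so all previously seen values are cached.
 */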
private[serializer] class GenericAvroSerializer(schemas: Map[Long, String])
  extends KSerializer[GenericRecord] {

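  // Memoized results of compressing schemas and of parsing compressed schema bytes.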
  private val compressCache = new mutable.HashMap[Schema, Array[Byte]]()
  private val decompressCache = new mutable.HashMap[ByteBuffer, Schema]()

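  // Cached Avro datum writers and readers, one per schema.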
  private val writerCache = new mutable.HashMap[Schema, DatumWriter[_]]()
  private val readerCache = new mutable.HashMap[Schema, DatumReader[_]]()

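  // Schema-to-fingerprint and fingerprint-to-schema mappings, populated lazily.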
  private val fingerprintCache = new mutable.HashMap[Schema, Long]()
  private val schemaCache = new mutable.HashMap[Long, Schema]()

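  // Codec used to compress schema JSON that is written inline with the data.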
  private lazy val codec = CompressionCodec.createCodec(MonkeyEnv.get.conf)

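  /**
   * Compresses a schema's JSON representation for writing over the wire.
   * Results are memoized so the same schema is never compressed twice.
   */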
  def compress(schema: Schema): Array[Byte] = compressCache.getOrElseUpdate(schema, {
    val bos = new ByteArrayOutputStream()
    val out = codec.compressedOutputStream(bos)
    out.write(schema.toString.getBytes("UTF-8"))
    out.close()
    bos.toByteArray
  })

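  /**
   * Decompresses schema bytes back into an in-memory `Schema`.
   * Already-seen buffers are cached to avoid decompressing and reparsing.
   */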
  def decompress(schemaBytes: ByteBuffer): Schema = decompressCache.getOrElseUpdate(schemaBytes, {
    // Respect the buffer's array offset and position instead of assuming it
    // spans the whole backing array.
    val bis = new ByteArrayInputStream(
      schemaBytes.array(),
      schemaBytes.arrayOffset() + schemaBytes.position(),
      schemaBytes.remaining())
    val bytes = IOUtils.toByteArray(codec.compressedInputStream(bis))
    new Parser().parse(new String(bytes, "UTF-8"))
  })

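  /**
   * Serializes a record to the given Kryo output: a boolean flag, then either the
   * schema's fingerprint (if pre-registered) or the compressed schema, then the
   * Avro-encoded datum.
   */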
  def serializeDatum[R <: GenericRecord](datum: R, output: KryoOutput): Unit = {
    val encoder = EncoderFactory.get().binaryEncoder(output, null)
    val schema = datum.getSchema
    val fingerprint = fingerprintCache.getOrElseUpdate(schema, {
      SchemaNormalization.parsingFingerprint64(schema)
    })
    schemas.get(fingerprint) match {
      case Some(_) =>
        // Pre-registered schema: the fingerprint alone identifies it on the read side.
        output.writeBoolean(true)
        output.writeLong(fingerprint)
      case None =>
        // Unregistered schema: flag as false and write the compressed schema inline.
        output.writeBoolean(false)
        val compressedSchema = compress(schema)
        output.writeInt(compressedSchema.length)
        output.writeBytes(compressedSchema)
    }

    writerCache.getOrElseUpdate(schema, GenericData.get.createDatumWriter(schema))
      .asInstanceOf[DatumWriter[R]]
      .write(datum, encoder)
    encoder.flush()
  }

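  /**
   * Deserializes a record from the given Kryo input, resolving its schema either from
   * a fingerprint (for pre-registered schemas) or from a compressed schema written
   * inline, then decoding the datum.
   */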
  def deserializeDatum(input: KryoInput): GenericRecord = {
    val schema = {
      if (input.readBoolean()) {
        val fingerprint = input.readLong()
        schemaCache.getOrElseUpdate(fingerprint, {
          schemas.get(fingerprint) match {
            case Some(s) => new Parser().parse(s)
            case None =>
              throw new MonkeyException(
                "Error attempting to read avro data -- encountered an unknown " +
                  s"fingerprint: $fingerprint, not sure what schema to use. This could happen " +
                  "if you registered additional schemas after starting your application.")
          }
        })
      } else {
        val length = input.readInt()
        decompress(ByteBuffer.wrap(input.readBytes(length)))
      }
    }
    val decoder = DecoderFactory.get().directBinaryDecoder(input, null)
    readerCache.getOrElseUpdate(schema, GenericData.get.createDatumReader(schema))
      .asInstanceOf[DatumReader[GenericRecord]]
      .read(null, decoder)
  }

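  // Kryo Serializer entry points; both delegate to the methods above.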
  override def write(kryo: Kryo, output: KryoOutput, datum: GenericRecord): Unit =
    serializeDatum(datum, output)

  override def read(kryo: Kryo, input: KryoInput, datumClass: Class[GenericRecord]): GenericRecord =
    deserializeDatum(input)
}
